package algorthm.systemTraning.unionSearch;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Stack;

/**
 * 并查集: 假设有 N 个集合，要实现两个操作，一个是 union 将包含 A 和 B 两个元素的集合合并。
 * 一个是 isSameSet A 和 B 两个元素是否在同一个集合里面
 */
public class UnionSearch<V> {
    /**
     * 返回 a 和 b 是否在一个集合里面
     */
    public boolean isSameSet(V a , V b){
        return false;
    }
    public static class Node<V>{
        public V value ;
        public Node(V v){
            value = v ;
        }

        @Override
        public boolean equals(Object obj) {
            return this.value.equals(obj) ;
        }
    }
    public static class UnionSet<V>{
        // 设置 V 和 Node 的对应关系。
        private Map<V,Node<V>> nodes = new HashMap<>();
        // 保存 Node 和 Node 直接的父子关系
        private Map<Node<V> , Node<V>> parents = new HashMap<>();
        // 保存每个大集合的元素大小
        private Map<Node<V> , Integer> sizeMap = new HashMap<>();
        public UnionSet(List<V> values){
            for (V value : values) {
                Node node = new Node(value);
                nodes.put(value , node);
                parents.put(node , node);
                sizeMap.put(node , 1);
            }
        }
        public Node<V> getFather(V v){
            Stack<Node<V>> stack = new Stack<>();
            Node<V> curr =  nodes.get(v);
            while(!parents.get(curr).equals(curr)){
                curr = nodes.get(v);
            }
            // 这里是做了优化，举个例子，本来是  A -> B -> C -> D
            // 这样的结果，下面的优化后，就变成了 A -> D , B -> D , C -> D 这样的结构。
            // 这样优化后，再去查询的时候，就能做到一步到位了。gi
            Node<V> head = stack.pop();
            while(!stack.empty()){
                parents.put(stack.pop() , head);
            }
            return curr ;
        }
        public boolean isSameSet(V a , V b){
            return getFather(a).equals(getFather(b));
        }
        public void union(V a , V b){
            Node<V> A = nodes.get(a);
            Node<V> B = nodes.get(b);
            if(!A.equals(B)){
                int aSize = sizeMap.get(A);
                int bSize = sizeMap.get(B);
                Node<V> big = (aSize > bSize ? A : B);
                Node<V> small = (aSize <= bSize ? A : B);
                parents.put(small , big);
                sizeMap.put(big , aSize + bSize);
            }
        }
    }
}
